package dessert
package passes

import firrtl._
import firrtl.annotations._
import mdf.macrolib._
import mdf.macrolib.Utils.writeMDFToString
import java.io.{File, FileWriter}

/** One memory entry described by a ReplSeqMems configuration file, as populated by `MemConfReader`.
  *
  * @param name        memory name, taken from the conf `name` field
  * @param depth       memory depth, taken from the conf `depth` field
  * @param width       memory width, taken from the conf `width` field
  * @param readers     port kinds classified as read ports (`"read"`)
  * @param writers     port kinds classified as write ports (`"write"` or `"mwrite"`)
  * @param readwriters port kinds classified as read-write ports (`"rw"` or `"mrw"`)
  * @param maskGran    write-mask granularity; the reader falls back to `width` when `mask_gran` is absent
  */
case class MemConf(name: String, depth: BigInt, width: BigInt,
                   readers: Seq[String], writers: Seq[String], readwriters: Seq[String],
                   maskGran: BigInt)

object MemConfReader {
  /** Keys recognised on a conf line. */
  sealed trait ConfField
  case object Name extends ConfField
  case object Depth extends ConfField
  case object Width extends ConfField
  case object Ports extends ConfField
  case object MaskGran extends ConfField
  type ConfFieldMap = Map[ConfField, String]

  /** Read a conf file generated by [[firrtl.passes.ReplSeqMems]].
    *
    * Each non-blank line describes one memory as whitespace-separated `key value` pairs, using the keys
    * `name`, `depth`, `width`, `ports` (comma-separated port kinds) and the optional `mask_gran`.
    * Port kinds are classified as follows:
    *  - `read` becomes a reader.
    *  - `write` and `mwrite` become writers.
    *  - `rw` and `mrw` become read-writers.
    *  - Any other kind is ignored.
    *
    * When `mask_gran` is absent, the mask granularity defaults to the width.
    *
    * The whole file is read (and closed) before any line is parsed.
    *
    * @param conf the configuration file to read
    * @return one [[MemConf]] per non-blank line, in file order
    * @throws java.util.NoSuchElementException if a line lacks `name`, `depth`, `width` or `ports`
    * @throws NumberFormatException if a numeric field is not an integer
    *
    * An unrecognised key is reported through `firrtl.Utils.error`.
    */
  def apply(conf: java.io.File): Seq[MemConf] = {
    @scala.annotation.tailrec
    def parse(map: ConfFieldMap, list: List[String]): ConfFieldMap = list match {
      case Nil => map
      case "name" :: value :: tail => parse(map + (Name -> value), tail)
      case "depth" :: value :: tail => parse(map + (Depth -> value), tail)
      case "width" :: value :: tail => parse(map + (Width -> value), tail)
      case "ports" :: value :: tail => parse(map + (Ports -> value), tail)
      case "mask_gran" :: value :: tail => parse(map + (MaskGran -> value), tail)
      case field :: tail => firrtl.Utils.error(s"Unknown field $field")
    }

    // Materialise all lines eagerly so the file handle can be released before parsing.
    // (On Scala 2.12, `getLines.toSeq` is a lazy Stream that would keep reading after close.)
    val source = scala.io.Source.fromFile(conf)
    val lines = try source.getLines.toList finally source.close()

    // Skip blank lines and tolerate runs of whitespace between tokens; otherwise the
    // empty tokens they produce would be rejected as an "Unknown field".
    lines filter (_.trim.nonEmpty) map { line =>
      val map = parse(Map[ConfField, String](), line.trim.split("\\s+").toList)
      val ports = map(Ports) split ","
      MemConf(map(Name), BigInt(map(Depth)), BigInt(map(Width)),
        ports filter (_ == "read"),
        ports filter (p => p == "write" || p == "mwrite"),
        ports filter (p => p == "rw" || p == "mrw"),
        map get MaskGran map (BigInt(_)) getOrElse (BigInt(map(Width))))
    }
  }
}

/** Asks the `ConfToJSON` transform to convert the memory configuration file `conf`
  * into the JSON text produced by `writeMDFToString`, written to `json`.
  *
  * @param conf input ReplSeqMems configuration file
  * @param json output file that receives the generated JSON
  */
case class ConfToJSONAnnotation(
  conf: File,
  json: File
) extends NoTargetAnnotation

object ConfToJSON extends Transform {
  def inputForm = MidForm
  def outputForm = MidForm

  /** Converts a memory dimension to `Int`.
    *
    * `BigInt.toInt` would silently truncate an out-of-range value, so this fails loudly instead.
    *
    * @throws IllegalArgumentException if `value` does not fit in an `Int`
    */
  private def toIntChecked(value: BigInt, what: String, mem: MemConf): Int = {
    require(value.isValidInt, s"$what of memory ${mem.name} ($value) does not fit in an Int")
    value.toInt
  }

  /** Builds an mdf `SRAMMacro` for one memory.
    *
    * One `MacroPort` is emitted for each reader (`R<i>_*`), each writer (`W<i>_*`) and each
    * read-writer (`RW<i>_*`). All signals are `ActiveHigh`.
    *
    * Writers and read-writers whose port kind starts with `m` (`mwrite`, `mrw`) also get a
    * mask port and the memory's mask granularity.
    *
    * The third `SRAMMacro` argument is passed as an empty string.
    * NOTE(review): presumably the macro family; confirm against mdf.
    *
    * @throws IllegalArgumentException if the depth, the width, or (for a masked port) the mask
    *         granularity does not fit in an `Int`
    */
  def toSRAMMacro(mem: MemConf): SRAMMacro = {
    val depth = toIntChecked(mem.depth, "depth", mem)
    val width = toIntChecked(mem.width, "width", mem)
    // Only converted when a masked port actually needs it, matching the original behavior.
    def granularity: Int = toIntChecked(mem.maskGran, "mask granularity", mem)
    // "mwrite" / "mrw" denote masked ports; "write" / "rw" are unmasked.
    def isMasked(kind: String): Boolean = kind startsWith "m"

    val readPorts = mem.readers.zipWithIndex map { case (_, i) => MacroPort(
      PolarizedPort(s"R${i}_addr", ActiveHigh),
      PolarizedPort(s"R${i}_clk", ActiveHigh),
      output = Some(PolarizedPort(s"R${i}_data", ActiveHigh)),
      chipEnable = Some(PolarizedPort(s"R${i}_en", ActiveHigh)),
      depth = Some(depth),
      width = Some(width)
    )}
    val writePorts = mem.writers.zipWithIndex map { case (kind, i) =>
      val masked = isMasked(kind)
      MacroPort(
        PolarizedPort(s"W${i}_addr", ActiveHigh),
        PolarizedPort(s"W${i}_clk", ActiveHigh),
        input = Some(PolarizedPort(s"W${i}_data", ActiveHigh)),
        chipEnable = Some(PolarizedPort(s"W${i}_en", ActiveHigh)),
        maskPort = if (masked) Some(PolarizedPort(s"W${i}_mask", ActiveHigh)) else None,
        maskGran = if (masked) Some(granularity) else None,
        depth = Some(depth),
        width = Some(width)
      )
    }
    val readwritePorts = mem.readwriters.zipWithIndex map { case (kind, i) =>
      val masked = isMasked(kind)
      MacroPort(
        PolarizedPort(s"RW${i}_addr", ActiveHigh),
        PolarizedPort(s"RW${i}_clk", ActiveHigh),
        input = Some(PolarizedPort(s"RW${i}_wdata", ActiveHigh)),
        output = Some(PolarizedPort(s"RW${i}_rdata", ActiveHigh)),
        chipEnable = Some(PolarizedPort(s"RW${i}_en", ActiveHigh)),
        writeEnable = Some(PolarizedPort(s"RW${i}_wmode", ActiveHigh)),
        maskPort = if (masked) Some(PolarizedPort(s"RW${i}_wmask", ActiveHigh)) else None,
        maskGran = if (masked) Some(granularity) else None,
        depth = Some(depth),
        width = Some(width)
      )
    }
    SRAMMacro(mem.name, width, depth, "",
              readPorts ++ writePorts ++ readwritePorts, Nil)
  }

  /** For every [[ConfToJSONAnnotation]], reads its conf file and writes the generated JSON to its json file.
    *
    * The circuit state is returned unchanged.
    */
  def execute(state: CircuitState): CircuitState = {
    state.annotations foreach {
      case ConfToJSONAnnotation(conf, json) =>
        val macros = MemConfReader(conf) map toSRAMMacro
        val writer = new FileWriter(json)
        // Release the file handle even if serialisation or writing fails.
        try writer.write(writeMDFToString(macros))
        finally writer.close()
      case _ =>
    }
    state
  }
}
